import pandas as pd
import os
import argparse


def categorize_age(age: float) -> "str | None":
    """Return the age-bucket label for *age*, or ``None`` if it cannot be bucketed.

    Buckets are "18-24", "25-34", "35-44", "45-54" and "55+". Each bucket extends
    up to (but not including) the next bucket's lower bound, so fractional ages
    such as 24.5 land in "18-24" instead of falling through to "55+".

    Args:
        age: A numeric age or anything ``float()`` accepts (e.g. the string "30").

    Returns:
        The bucket label, or ``None`` for missing/NaN values, values that cannot
        be converted to float, and ages under 18 (minors are excluded).
    """
    if pd.isna(age):
        return None
    try:
        age = float(age)
    except (TypeError, ValueError):
        return None
    # A string such as "nan" survives the isna check above but converts to NaN,
    # which compares False against every bound and would otherwise end up "55+".
    if pd.isna(age) or age < 18:
        return None  # missing or minor: excluded
    if age < 25:
        return "18-24"
    if age < 35:
        return "25-34"
    if age < 45:
        return "35-44"
    if age < 55:
        return "45-54"
    return "55+"


def gender_label(gender_value) -> str:
    """Map a ``gender.x`` value to ``"male"`` or ``"female"``.

    Values numerically equal to 1 (``1``, ``1.0``, ``"1"``, ``"1.0"``) map to
    "male". Everything else maps to "female", including NaN/None and values
    that are not numeric at all (the original specification).

    The comparison is numeric on purpose: pandas stores an integer column that
    contains any missing value as float64, turning 1 into 1.0. A string
    comparison against "1" would then label every such row "female".
    """
    # NOTE(review): treating unknown/missing gender as "female" is the stated
    # spec, but it silently merges missing data into one group — confirm intent.
    try:
        return "male" if float(gender_value) == 1 else "female"
    except (TypeError, ValueError):
        return "female"


def main(input_csv: str, filtered_csv: str, aggregated_csv: str):
    """Bucket participants by age and gender, then aggregate responses per website.

    Reads ``input_csv`` and derives a ``group`` label of the form
    ``"<age bucket>_<gender>"`` (e.g. ``"25-34_male"``) for every row whose age
    can be bucketed. It then writes two CSV files:

    * ``filtered_csv``: columns ``website``, ``participant_id``, ``group`` and
      ``response`` for every row that received a group.
    * ``aggregated_csv``: the mean ``response`` per (``website``, ``group``),
      in a column named ``mean_response``.

    Parent directories of both output paths are created if they are missing.

    Raises:
        ValueError: if ``website``, ``participant_id`` or ``gender.x`` is
            missing, if neither ``age.x`` nor ``years0.x`` is present, or if no
            known response column is present.
    """
    # low_memory=False parses the whole file in one pass, so each column gets a
    # single consistent dtype instead of mixed dtypes inferred chunk by chunk.
    df = pd.read_csv(input_csv, low_memory=False)

    missing = [col for col in ("website", "participant_id", "gender.x") if col not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")

    # Prefer 'age.x'; fall back to 'years0.x'.
    age_column = next((col for col in ("age.x", "years0.x") if col in df.columns), None)
    if age_column is None:
        raise ValueError("Neither 'age.x' nor 'years0.x' present in the input CSV. Cannot bucket ages.")

    # The first candidate present wins, in this priority order.
    response_candidates = ("mean_response", "mean_score", "response", "response.x", "response.y")
    response_column = next((col for col in response_candidates if col in df.columns), None)
    if response_column is None:
        raise ValueError("No response column found (expected one of mean_response, mean_score, response, response.x, response.y)")

    # Create group column
    df["age_bucket"] = df[age_column].apply(categorize_age)
    df["gender_label"] = df["gender.x"].apply(gender_label)
    # A comprehension instead of df.apply(axis=1): it avoids building a Series per
    # row, and on a zero-row frame apply(axis=1) returns a DataFrame, which cannot
    # be assigned to a single column.
    df["group"] = [
        f"{bucket}_{gender}" if pd.notna(bucket) else None
        for bucket, gender in zip(df["age_bucket"], df["gender_label"])
    ]

    # Drop rows without a group (minors, missing or unparseable age) and keep only
    # the output columns. The rename is chained rather than applied with
    # inplace=True on a column subset, which can raise SettingWithCopyWarning.
    filtered_cols = ["website", "participant_id", "group", response_column]
    df_filtered_simple = (
        df.loc[df["group"].notna(), filtered_cols]
        .rename(columns={response_column: "response"})
    )

    # Compute aggregated mean per website and group
    agg_df = (
        df_filtered_simple.groupby(["website", "group"], as_index=False)["response"]
        .mean()
        .rename(columns={"response": "mean_response"})
    )

    # Make sure the output directories exist before writing.
    for out_path in (filtered_csv, aggregated_csv):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    df_filtered_simple.to_csv(filtered_csv, index=False)
    agg_df.to_csv(aggregated_csv, index=False)

    print(f"Filtered data with group column written to: {filtered_csv}")
    print(f"Aggregated mean responses written to: {aggregated_csv}")


if __name__ == "__main__":
    # Command-line entry point: parse the three file paths and run the pipeline.
    # NOTE: the default paths are placeholders ("path/to/..."); pass real
    # locations via --input / --filtered / --aggregated.
    cli = argparse.ArgumentParser(
        description="Process human correlation CSV and create group-based aggregates."
    )
    cli.add_argument(
        "--input",
        "-i",
        default="path/to/ae_only_unambiguous_1000.csv",
        help="Path to input CSV file.",
    )
    cli.add_argument(
        "--filtered",
        "-f",
        default="path/to/ae_only_unambiguous_with_group.csv",
        help="Path to output filtered CSV file with group column.",
    )
    cli.add_argument(
        "--aggregated",
        "-a",
        default="path/to/website_group_mean_response.csv",
        help="Path to output aggregated CSV file.",
    )

    opts = cli.parse_args()
    main(
        input_csv=opts.input,
        filtered_csv=opts.filtered,
        aggregated_csv=opts.aggregated,
    )
